#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for textio module."""
# pytype: skip-file

import bz2
import glob
import gzip
import logging
import os
import platform
import re
import shutil
import tempfile
import unittest
import zlib
from datetime import datetime

import pytz

import apache_beam as beam
from apache_beam import coders
from apache_beam.io import iobase
from apache_beam.io import source_test_utils
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
# Importing the following private classes for testing.
from apache_beam.io.textio import ReadAllFromText
from apache_beam.io.textio import ReadAllFromTextContinuously
from apache_beam.io.textio import ReadFromText
from apache_beam.io.textio import ReadFromTextWithFilename
from apache_beam.io.textio import WriteToText
from apache_beam.io.textio import _TextSink as TextSink
from apache_beam.io.textio import _TextSource as TextSource
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.test_utils import TempDir
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.core import Create
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.transforms.util import LogElements
from apache_beam.utils.timestamp import Timestamp


class DummyCoder(coders.Coder):
  """Decode-only coder used to check that a custom coder is applied.

  Decoding repeats the raw value twice before interpreting it as UTF-8, so a
  record read through this coder differs from the text that was written.
  """
  def encode(self, x):
    """Always raises ``ValueError``; encoding is not supported."""
    raise ValueError()

  def decode(self, x):
    """Returns ``x`` repeated twice and decoded as UTF-8."""
    doubled = x * 2
    return doubled.decode('utf-8')

  def to_type_hint(self):
    """Returns ``str`` as the type hint of decoded values."""
    return str


class EOL(object):
  """Integer constants naming the line-ending schemes used by write_data."""
  (
      LF,  # every line ends with b'\n'
      CRLF,  # every line ends with b'\r\n'
      MIXED,  # lines alternate between b'\n' and b'\r\n'
      LF_WITH_NOTHING_AT_LAST_LINE,  # b'\n' except after the final line
      CUSTOM_DELIMITER,  # every line ends with a caller-supplied delimiter
  ) = range(1, 6)


def write_data(
    num_lines,
    no_data=False,
    directory=None,
    prefix=tempfile.template,
    eol=EOL.LF,
    custom_delimiter=None,
    line_value=b'line'):
  """Creates a temporary file holding ``num_lines`` generated lines.

  The file is created with ``delete=False`` and is therefore still on disk
  after this function returns.

  Args:
    num_lines (int): The number of lines to write.
    no_data (bool): If :data:`True`, every line is empty; otherwise line ``i``
      is ``line_value`` followed by the decimal digits of ``i``.
    directory (str): The directory to create the temporary file in.
    prefix (str): The prefix to use for the temporary file name.
    eol (int): The line-ending scheme, one of the attributes of
      :class:`~apache_beam.io.textio_test.EOL`.
    custom_delimiter (bytes): The separator written after each line when
      ``eol`` is ``EOL.CUSTOM_DELIMITER``.
    line_value (bytes): Leading bytes of every generated line, default
      b'line'.

  Returns:
    Tuple[str, List[str]]: The file name and the utf-8 decoded lines that were
      written (without separators).

  Raises:
    ValueError: While writing the first line, if ``eol`` is not a known scheme
      or if ``EOL.CUSTOM_DELIMITER`` is requested with a missing or empty
      ``custom_delimiter``.
  """
  lf, crlf = b'\n', b'\r\n'

  def separator_after(index):
    # Picks the bytes that terminate line number `index` under `eol`.
    if eol == EOL.LF:
      return lf
    if eol == EOL.CRLF:
      return crlf
    if eol == EOL.MIXED:
      return (lf, crlf)[index % 2]
    if eol == EOL.LF_WITH_NOTHING_AT_LAST_LINE:
      return lf if index != (num_lines - 1) else b''
    if eol == EOL.CUSTOM_DELIMITER:
      if custom_delimiter is not None and len(custom_delimiter) > 0:
        return custom_delimiter
      raise ValueError('delimiter can not be null or empty')
    raise ValueError('Received unknown value %s for eol.' % eol)

  payloads = []
  with tempfile.NamedTemporaryFile(delete=False, dir=directory,
                                   prefix=prefix) as out:
    for index in range(num_lines):
      payload = line_value + str(index).encode() if not no_data else b''
      payloads.append(payload)
      out.write(payload + separator_after(index))

    return out.name, [payload.decode('utf-8') for payload in payloads]


def write_pattern(lines_per_file, no_data=False, return_filenames=False):
  """Writes a group of temporary files that share one glob pattern.

  A fresh temporary directory is created and, for every entry of
  ``lines_per_file``, one file with the prefix ``mytemp`` is written into it
  through :func:`write_data`.

  Args:
    lines_per_file (List[int]): The number of lines to write per file.
    no_data (bool): If :data:`True`, empty lines will be written, otherwise
      each line will contain a concatenation of b'line' and the line number.
    return_filenames (bool): If True, returned list will contain
      (filename, data) pairs.

  Returns:
    Tuple[str, List[Union[str, (str, str)]]]: A tuple of the filename pattern
      and a list of the utf-8 decoded written data or (filename, data) pairs.

  Raises:
    AssertionError: If ``lines_per_file`` is empty, so that no file was
      written.
  """
  temp_dir = tempfile.mkdtemp()

  all_data = []
  file_name = None
  for num_lines in lines_per_file:
    file_name, data = write_data(
        num_lines, no_data=no_data, directory=temp_dir, prefix='mytemp')
    if return_filenames:
      all_data.extend((file_name, line) for line in data)
    else:
      all_data.extend(data)

  # A pattern that is not backed by at least one file is a test-setup error.
  assert file_name
  # The pattern lives next to the files that were just written.
  return os.path.join(os.path.dirname(file_name), 'mytemp*'), all_data


class TextSourceTest(unittest.TestCase):

  # Number of records that will be written by most tests.
  DEFAULT_NUM_RECORDS = 100

  def _run_read_test(
      self,
      file_or_pattern,
      expected_data,
      buffer_size=DEFAULT_NUM_RECORDS,
      compression=CompressionTypes.UNCOMPRESSED,
      delimiter=None,
      escapechar=None):
    """Reads ``file_or_pattern`` through a ``TextSource`` and checks records.

    The whole range of the source is read with a ``StrUtf8Coder`` and the
    result is compared with ``expected_data`` ignoring order.

    Args:
      file_or_pattern: File name or glob pattern handed to ``TextSource``.
      expected_data: The records the read is expected to produce.
      buffer_size: Read buffer size handed to the source. The default is
        deliberately small: records usually take more than one byte each, so
        the buffer is smaller than the file and reads cross buffer boundaries.
      compression: Compression type handed to the source.
      delimiter: Forwarded to the source only when truthy.
      escapechar: Forwarded to the source only when truthy.
    """
    # Falsy optional settings are left out so the source's defaults apply.
    optional_settings = {
        name: value
        for name, value in (('delimiter', delimiter),
                            ('escapechar', escapechar)) if value
    }
    text_source = TextSource(
        file_or_pattern,
        0,
        compression,
        True,
        coders.StrUtf8Coder(),
        buffer_size,
        **optional_settings)
    tracker = text_source.get_range_tracker(None, None)
    actual_records = list(text_source.read(tracker))
    self.assertCountEqual(expected_data, actual_records)

  @unittest.skipIf(platform.system() == 'Windows', 'Skipping on Windows')
  def test_read_from_text_file_pattern_with_dot_slash(self):
    """Relative patterns with and without a leading ``./`` read the same data.

    Two ``.txt`` files are created in a temporary directory, the process
    changes into that directory, and both ``./*.txt`` and ``*.txt`` are
    expected to yield the same records.
    """
    cwd = os.getcwd()
    expected = ['abc', 'de']
    with TempDir() as temp_dir:
      temp_dir.create_temp_file(suffix='.txt', lines=[b'a', b'b', b'c'])
      temp_dir.create_temp_file(suffix='.txt', lines=[b'd', b'e'])

      os.chdir(temp_dir.get_path())
      try:
        with TestPipeline() as p:
          dot_slash = p | 'ReadDotSlash' >> ReadFromText('./*.txt')
          no_dot_slash = p | 'ReadNoSlash' >> ReadFromText('*.txt')

          assert_that(dot_slash, equal_to(expected))
          assert_that(no_dot_slash, equal_to(expected))
      finally:
        # Restore the working directory even when the pipeline or one of its
        # assertions fails, so later tests do not keep running from inside
        # the temporary directory.
        os.chdir(cwd)

  def test_read_from_text_with_value_provider(self):
    """Reading the value-provider pattern 'abc' is expected to raise OSError."""
    class UserDefinedOptions(PipelineOptions):
      """Pipeline options exposing ``--file_pattern`` as a value provider."""
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--file_pattern',
            help='This keyword argument is a value provider',
            default='some value')

    pipeline_args = ['--file_pattern', 'abc']
    options = UserDefinedOptions(pipeline_args)
    with self.assertRaises(OSError), \
        TestPipeline(options=options) as pipeline:
      _ = pipeline | 'Read' >> ReadFromText(options.file_pattern)

  def test_read_single_file(self):
    """Reads back every record of one file written with default settings."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    path, records = write_data(num_records)
    assert len(records) == num_records
    self._run_read_test(path, records)

  def test_read_single_file_smaller_than_default_buffer(self):
    """Reads a DEFAULT_NUM_RECORDS-line file with the source's default buffer."""
    path, records = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
    default_buffer = TextSource.DEFAULT_READ_BUFFER_SIZE
    self._run_read_test(path, records, buffer_size=default_buffer)

  def test_read_single_file_larger_than_default_buffer(self):
    """Reads a file with as many lines as the default buffer has bytes."""
    default_buffer = TextSource.DEFAULT_READ_BUFFER_SIZE
    path, records = write_data(default_buffer)
    self._run_read_test(path, records, buffer_size=default_buffer)

  def test_read_file_pattern(self):
    """Reads six files of differing sizes through one glob pattern."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    # Per-file sizes, as multiples of the default record count (40 in total).
    size_factors = (5, 3, 12, 8, 8, 4)
    pattern, records = write_pattern(
        [num_records * factor for factor in size_factors])
    assert len(records) == num_records * 40
    self._run_read_test(pattern, records)

  def test_read_single_file_windows_eol(self):
    """Reads a single file in which every line ends with CRLF."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    path, records = write_data(num_records, eol=EOL.CRLF)
    assert len(records) == num_records
    self._run_read_test(path, records)

  def test_read_single_file_mixed_eol(self):
    """Reads a single file whose lines alternate between LF and CRLF."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    path, records = write_data(num_records, eol=EOL.MIXED)
    assert len(records) == num_records
    self._run_read_test(path, records)

  def test_read_single_file_last_line_no_eol(self):
    """Reads a single file whose final line has no line terminator."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    path, records = write_data(
        num_records, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
    assert len(records) == num_records
    self._run_read_test(path, records)

  def test_read_single_file_single_line_no_eol(self):
    """Reads a file consisting of one line without a line terminator."""
    path, records = write_data(1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
    assert len(records) == 1

    self._run_read_test(path, records)

  def test_read_empty_single_file(self):
    """A file holding one empty, unterminated line yields no records."""
    path, written = write_data(
        1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE, no_data=True)
    assert len(written) == 1

    # `written` is a single empty string that was stored without an end of
    # line character, so nothing reached the file and the read must come back
    # empty.
    nothing_expected = []
    self._run_read_test(path, nothing_expected)

  def test_read_single_file_last_line_no_eol_gzip(self):
    """Reads a gzip copy of a file whose final line has no terminator."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    plain_path, records = write_data(
        num_records, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)

    # Compress the plain file into a sibling '.gz' file.
    compressed_path = plain_path + '.gz'
    with open(plain_path, 'rb') as src:
      with gzip.open(compressed_path, 'wb') as dst:
        dst.write(src.read())

    assert len(records) == num_records
    self._run_read_test(
        compressed_path, records, compression=CompressionTypes.GZIP)

  def test_read_single_file_single_line_no_eol_gzip(self):
    """Reads a gzip copy of a file with one unterminated line."""
    plain_path, records = write_data(
        1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)

    # Compress the plain file into a sibling '.gz' file.
    compressed_path = plain_path + '.gz'
    with open(plain_path, 'rb') as src:
      with gzip.open(compressed_path, 'wb') as dst:
        dst.write(src.read())

    assert len(records) == 1
    self._run_read_test(
        compressed_path, records, compression=CompressionTypes.GZIP)

  def test_read_empty_single_file_no_eol_gzip(self):
    """A gzip copy of an empty, unterminated file yields no records."""
    plain_path, written = write_data(
        1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE, no_data=True)

    # Compress the plain file into a sibling '.gz' file.
    compressed_path = plain_path + '.gz'
    with open(plain_path, 'rb') as src:
      with gzip.open(compressed_path, 'wb') as dst:
        dst.write(src.read())

    assert len(written) == 1
    # `written` is a single empty string that was stored without an end of
    # line character, so nothing reached the file and the read must come back
    # empty.
    nothing_expected = []
    self._run_read_test(
        compressed_path, nothing_expected, compression=CompressionTypes.GZIP)

  def test_read_single_file_with_empty_lines(self):
    """Reads a file of LF-terminated empty lines as empty records."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    path, records = write_data(num_records, eol=EOL.LF, no_data=True)

    assert len(records) == num_records
    # Sanity check that the helper really produced empty lines.
    assert not records[0]

    self._run_read_test(path, records)

  def test_read_single_file_without_striping_eol_lf(self):
    """Records keep their trailing LF when the source is built with False.

    The fourth positional ``TextSource`` argument is False here (presumably
    the strip-trailing-newlines switch), unlike the other tests in this class.
    """
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    path, written = write_data(num_records, eol=EOL.LF)
    assert len(written) == num_records

    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, False, coders.StrUtf8Coder())
    tracker = text_source.get_range_tracker(None, None)
    actual_records = list(text_source.read(tracker))

    expected_records = [line + '\n' for line in written]
    self.assertCountEqual(expected_records, actual_records)

  def test_read_single_file_without_striping_eol_crlf(self):
    """Records keep their trailing CRLF when the source is built with False.

    The fourth positional ``TextSource`` argument is False here (presumably
    the strip-trailing-newlines switch), unlike the other tests in this class.
    """
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    path, written = write_data(num_records, eol=EOL.CRLF)
    assert len(written) == num_records

    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, False, coders.StrUtf8Coder())
    tracker = text_source.get_range_tracker(None, None)
    actual_records = list(text_source.read(tracker))

    expected_records = [line + '\r\n' for line in written]
    self.assertCountEqual(expected_records, actual_records)

  def test_read_file_pattern_with_empty_files(self):
    """Reads six files of empty lines through one glob pattern."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    # Per-file sizes, as multiples of the default record count (40 in total).
    size_factors = (5, 3, 12, 8, 8, 4)
    pattern, records = write_pattern(
        [factor * num_records for factor in size_factors], no_data=True)
    assert len(records) == 40 * num_records
    # Sanity check that the helper really produced empty lines.
    assert not records[0]
    self._run_read_test(pattern, records)

  def test_read_after_splitting(self):
    """Checks the splits of a 10-line source against the unsplit source."""
    path, records = write_data(10)
    assert len(records) == 10
    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())

    # A desired bundle size of 33 bytes is requested for the 10-line file.
    bundles = list(text_source.split(desired_bundle_size=33))
    split_infos = [
        (bundle.source, bundle.start_position, bundle.stop_position)
        for bundle in bundles
    ]
    unsplit_info = (text_source, None, None)
    source_test_utils.assert_sources_equal_reference_source(
        unsplit_info, split_infos)

  def test_header_processing(self):
    """Header callbacks receive the first five lines; the rest are records."""
    path, written = write_data(10)
    assert len(written) == 10

    header_part, body_part = written[:5], written[5:]
    collected_headers = []

    def header_matcher(line):
      # A line counts as header when it is one of the first five lines.
      return line in header_part

    def store_header(lines):
      # Remembers whatever header lines the source hands over.
      collected_headers.extend(lines)

    text_source = TextSource(
        path,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        header_processor_fns=(header_matcher, store_header))
    bundles = list(text_source.split(desired_bundle_size=100000))
    assert len(bundles) == 1
    only_bundle = bundles[0]
    tracker = only_bundle.source.get_range_tracker(
        only_bundle.start_position, only_bundle.stop_position)
    actual_records = list(text_source.read_records(path, tracker))

    self.assertCountEqual(header_part, collected_headers)
    self.assertCountEqual(body_part, actual_records)

  def test_progress(self):
    """Checks the progress reported while reading a 10-line file.

    After each record the range tracker's consumed fraction and split-point
    report are sampled and compared with the expected sequences.
    """
    path, records = write_data(10)
    assert len(records) == 10
    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())
    bundles = list(text_source.split(desired_bundle_size=100000))
    assert len(bundles) == 1
    only_bundle = bundles[0]

    tracker = only_bundle.source.get_range_tracker(
        only_bundle.start_position, only_bundle.stop_position)
    seen_fractions = []
    seen_split_points = []
    for _ in only_bundle.source.read(tracker):
      seen_fractions.append(tracker.fraction_consumed())
      seen_split_points.append(tracker.split_points())

    self.assertEqual([step / 10 for step in range(10)], seen_fractions)

    unknown = iobase.RangeTracker.SPLIT_POINTS_UNKNOWN
    expected_split_points = [(consumed, unknown) for consumed in range(9)]
    # At last split point, the remaining split points callback returns 1 since
    # the expected position of next record becomes equal to the stop position.
    expected_split_points.append((9, 1))

    self.assertEqual(expected_split_points, seen_split_points)

  def test_read_reentrant_without_splitting(self):
    """Runs the reentrant-read check on an unsplit 10-line source."""
    path, records = write_data(10)
    assert len(records) == 10
    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())
    unsplit_info = (text_source, None, None)
    source_test_utils.assert_reentrant_reads_succeed(unsplit_info)

  def test_read_reentrant_after_splitting(self):
    """Runs the reentrant-read check on the single split of a source."""
    path, records = write_data(10)
    assert len(records) == 10
    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())
    bundles = list(text_source.split(desired_bundle_size=100000))
    assert len(bundles) == 1
    only_bundle = bundles[0]
    bundle_info = (
        only_bundle.source,
        only_bundle.start_position,
        only_bundle.stop_position)
    source_test_utils.assert_reentrant_reads_succeed(bundle_info)

  def test_dynamic_work_rebalancing(self):
    """Runs the exhaustive split-at-fraction check on a 5-line LF file."""
    path, records = write_data(5)
    assert len(records) == 5
    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())
    bundles = list(text_source.split(desired_bundle_size=100000))
    assert len(bundles) == 1
    only_bundle = bundles[0]
    source_test_utils.assert_split_at_fraction_exhaustive(
        only_bundle.source,
        only_bundle.start_position,
        only_bundle.stop_position)

  def test_dynamic_work_rebalancing_windows_eol(self):
    """Runs the exhaustive split-at-fraction check on a 15-line CRLF file."""
    path, records = write_data(15, eol=EOL.CRLF)
    assert len(records) == 15
    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())
    bundles = list(text_source.split(desired_bundle_size=100000))
    assert len(bundles) == 1
    only_bundle = bundles[0]
    # The multi-threaded variant of the check is switched off for this file.
    source_test_utils.assert_split_at_fraction_exhaustive(
        only_bundle.source,
        only_bundle.start_position,
        only_bundle.stop_position,
        perform_multi_threaded_test=False)

  def test_dynamic_work_rebalancing_mixed_eol(self):
    """Runs the exhaustive split-at-fraction check on a mixed-EOL file."""
    path, records = write_data(5, eol=EOL.MIXED)
    assert len(records) == 5
    text_source = TextSource(
        path, 0, CompressionTypes.UNCOMPRESSED, True, coders.StrUtf8Coder())
    bundles = list(text_source.split(desired_bundle_size=100000))
    assert len(bundles) == 1
    only_bundle = bundles[0]
    # The multi-threaded variant of the check is switched off for this file.
    source_test_utils.assert_split_at_fraction_exhaustive(
        only_bundle.source,
        only_bundle.start_position,
        only_bundle.stop_position,
        perform_multi_threaded_test=False)

  def test_read_from_text_single_file(self):
    """ReadFromText on one 5-line file yields exactly the written lines."""
    path, written = write_data(5)
    assert len(written) == 5
    with TestPipeline() as pipeline:
      lines_read = pipeline | 'Read' >> ReadFromText(path)
      assert_that(lines_read, equal_to(written))

  def test_read_from_text_with_file_name_single_file(self):
    """ReadFromTextWithFilename pairs every line with its file name."""
    path, written = write_data(5)
    expected_pairs = [(path, line) for line in written]
    assert len(expected_pairs) == 5
    with TestPipeline() as pipeline:
      pairs_read = pipeline | 'Read' >> ReadFromTextWithFilename(path)
      assert_that(pairs_read, equal_to(expected_pairs))

  def test_read_all_single_file(self):
    """ReadAllFromText reads a file whose name arrives as an element."""
    path, written = write_data(5)
    assert len(written) == 5
    with TestPipeline() as pipeline:
      lines_read = (
          pipeline
          | 'Create' >> Create([path])
          | 'ReadAll' >> ReadAllFromText())
      assert_that(lines_read, equal_to(written))

  def test_read_all_many_single_files(self):
    """ReadAllFromText reads three files named by separate elements."""
    paths = []
    expected_lines = []
    for line_count in (5, 10, 15):
      path, written = write_data(line_count)
      assert len(written) == line_count
      paths.append(path)
      expected_lines.extend(written)

    with TestPipeline() as pipeline:
      lines_read = (
          pipeline
          | 'Create' >> Create(paths)
          | 'ReadAll' >> ReadAllFromText())
      assert_that(lines_read, equal_to(expected_lines))

  def test_read_all_unavailable_files_ignored(self):
    """An extra '/unavailable_file' element adds nothing to the output.

    Three real files plus the path '/unavailable_file' are fed to
    ReadAllFromText; only the lines of the three real files are expected.
    """
    paths = []
    expected_lines = []
    for line_count in (5, 10, 15):
      path, written = write_data(line_count)
      assert len(written) == line_count
      paths.append(path)
      expected_lines.extend(written)
    paths.append("/unavailable_file")

    with TestPipeline() as pipeline:
      lines_read = (
          pipeline
          | 'Create' >> Create(paths)
          | 'ReadAll' >> ReadAllFromText())
      assert_that(lines_read, equal_to(expected_lines))

  class _WriteFilesFn(beam.DoFn):
    """Writes two files while the pipeline runs, then relays each element.

    While the 'count' state still reads 0, ``file1`` is (re)written with two
    lines and ``file2`` is created with one line; the state is then bumped so
    the files are not written again for that state.
    """
    COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

    def __init__(self, temp_path):
      # Directory in which 'file1' and 'file2' are written.
      self.temp_path = temp_path

    def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
      if count_state.read() == 0:
        count_state.add(1)
        deferred_files = (
            ('file1', 'second A\nsecond B'),
            ('file2', 'first'),
        )
        for name, text in deferred_files:
          with open(FileSystems.join(self.temp_path, name), 'w') as f:
            f.write(text)
      # Drop the dummy key and reduce the file path to its base name.
      keyed_value = element[1]
      path, content = keyed_value[0], keyed_value[1]
      yield FileSystems.split(path)[1], content

  def test_read_all_continuously_new(self):
    """With match_updated_files=False each file is expected only once.

    ``file1`` exists up front; ``_WriteFilesFn`` rewrites it and adds
    ``file2`` while the pipeline runs. Only the original content of ``file1``
    and the content of ``file2`` are expected.
    """
    with TempDir() as tempdir, TestPipeline() as pipeline:
      temp_path = tempdir.get_path()
      # Seed the directory with one file before the pipeline starts.
      with open(FileSystems.join(temp_path, 'file1'), 'w') as seed_file:
        seed_file.write('first')

      poll_interval = 0.5
      duration = 2
      read_new_files = ReadAllFromTextContinuously(
          FileSystems.join(temp_path, '*'),
          with_filename=True,
          start_timestamp=Timestamp.now(),
          interval=poll_interval,
          stop_timestamp=Timestamp.now() + duration,
          match_updated_files=False)
      write_files = beam.ParDo(self._WriteFilesFn(temp_path))

      p_read_once = (
          pipeline
          | 'Continuously read new files' >> read_new_files
          | 'add dumb key' >> beam.Map(lambda x: (0, x))
          | 'Write files on-the-fly' >> write_files)
      assert_that(
          p_read_once,
          equal_to([('file1', 'first'), ('file2', 'first')]),
          label='assert read new files results')

  def test_read_all_continuously_update(self):
    """With match_updated_files=True the rewritten file1 is expected again.

    ``file1`` exists up front; ``_WriteFilesFn`` rewrites it and adds
    ``file2`` while the pipeline runs. Both the original and the rewritten
    content of ``file1`` are expected, plus the content of ``file2``.
    """
    with TempDir() as tempdir, TestPipeline() as pipeline:
      temp_path = tempdir.get_path()
      # Seed the directory with one file before the pipeline starts.
      with open(FileSystems.join(temp_path, 'file1'), 'w') as seed_file:
        seed_file.write('first')

      poll_interval = 0.5
      duration = 2
      read_updated_files = ReadAllFromTextContinuously(
          FileSystems.join(temp_path, '*'),
          with_filename=True,
          start_timestamp=Timestamp.now(),
          interval=poll_interval,
          stop_timestamp=Timestamp.now() + duration,
          match_updated_files=True)
      write_files = beam.ParDo(self._WriteFilesFn(temp_path))

      p_read_upd = (
          pipeline
          | 'Continuously read updated files' >> read_updated_files
          | 'add dumb key' >> beam.Map(lambda x: (0, x))
          | 'Write files on-the-fly' >> write_files)
      expected_pairs = [
          ('file1', 'first'),
          ('file1', 'second A'),
          ('file1', 'second B'),
          ('file2', 'first'),
      ]
      assert_that(
          p_read_upd,
          equal_to(expected_pairs),
          label='assert read updated files results')

  def test_read_from_text_single_file_with_coder(self):
    """ReadFromText with DummyCoder yields every line repeated twice."""
    path, written = write_data(5)
    assert len(written) == 5
    doubled_lines = [line + line for line in written]
    with TestPipeline() as pipeline:
      lines_read = pipeline | 'Read' >> ReadFromText(path, coder=DummyCoder())
      assert_that(lines_read, equal_to(doubled_lines))

  def test_read_from_text_file_pattern(self):
    """ReadFromText on a glob pattern yields the lines of all six files."""
    lines_per_file = [5, 3, 12, 8, 8, 4]
    pattern, written = write_pattern(lines_per_file)
    assert len(written) == 40
    with TestPipeline() as pipeline:
      lines_read = pipeline | 'Read' >> ReadFromText(pattern)
      assert_that(lines_read, equal_to(written))

  def test_read_from_text_with_file_name_file_pattern(self):
    """ReadFromTextWithFilename on a pattern yields (file name, line) pairs."""
    pattern, expected_pairs = write_pattern(
        return_filenames=True, lines_per_file=[5, 5])
    assert len(expected_pairs) == 10
    with TestPipeline() as pipeline:
      pairs_read = pipeline | 'Read' >> ReadFromTextWithFilename(pattern)
      assert_that(pairs_read, equal_to(expected_pairs))

  def test_read_all_file_pattern(self):
    """ReadAllFromText expands a glob pattern that arrives as an element."""
    lines_per_file = [5, 3, 12, 8, 8, 4]
    pattern, written = write_pattern(lines_per_file)
    assert len(written) == 40
    with TestPipeline() as pipeline:
      patterns = pipeline | 'Create' >> Create([pattern])
      lines_read = patterns | 'ReadAll' >> ReadAllFromText()
      assert_that(lines_read, equal_to(written))

  def test_read_all_many_file_patterns(self):
    """ReadAllFromText reads three glob patterns given as elements."""
    # (lines per file, expected total number of lines) for each pattern.
    pattern_specs = (
        ([5, 3, 12, 8, 8, 4], 40),
        ([3, 7, 9], 19),
        ([11, 20, 5, 5], 41),
    )
    patterns = []
    expected_lines = []
    for lines_per_file, total_lines in pattern_specs:
      pattern, written = write_pattern(lines_per_file)
      assert len(written) == total_lines
      patterns.append(pattern)
      expected_lines.extend(written)

    with TestPipeline() as pipeline:
      lines_read = (
          pipeline
          | 'Create' >> Create(patterns)
          | 'ReadAll' >> ReadAllFromText())
      assert_that(lines_read, equal_to(expected_lines))

  def test_read_all_with_filename(self):
    """ReadAllFromText(with_filename=True) yields (file name, line) pairs."""
    pattern, expected_pairs = write_pattern([5, 3], return_filenames=True)
    assert len(expected_pairs) == 8

    with TestPipeline() as pipeline:
      patterns = pipeline | 'Create' >> Create([pattern])
      pairs_read = patterns | 'ReadAll' >> ReadAllFromText(with_filename=True)
      assert_that(pairs_read, equal_to(expected_pairs))

  def test_read_auto_bzip2(self):
    """A '.bz2' file is read without naming a compression type."""
    _, lines = write_data(15)
    payload = '\n'.join(lines).encode('utf-8')
    with TempDir() as tempdir:
      bz2_path = tempdir.create_temp_file(suffix='.bz2')
      with bz2.BZ2File(bz2_path, 'wb') as compressed:
        compressed.write(payload)

      with TestPipeline() as pipeline:
        lines_read = pipeline | 'Read' >> ReadFromText(bz2_path)
        assert_that(lines_read, equal_to(lines))

  def test_read_auto_deflate(self):
    """A '.deflate' file is read without naming a compression type."""
    _, lines = write_data(15)
    payload = zlib.compress('\n'.join(lines).encode('utf-8'))
    with TempDir() as tempdir:
      deflate_path = tempdir.create_temp_file(suffix='.deflate')
      with open(deflate_path, 'wb') as compressed:
        compressed.write(payload)

      with TestPipeline() as pipeline:
        lines_read = pipeline | 'Read' >> ReadFromText(deflate_path)
        assert_that(lines_read, equal_to(lines))

  def test_read_auto_gzip(self):
    """A '.gz' file is read without naming a compression type."""
    _, lines = write_data(15)
    payload = '\n'.join(lines).encode('utf-8')
    with TempDir() as tempdir:
      gz_path = tempdir.create_temp_file(suffix='.gz')

      with gzip.GzipFile(gz_path, 'wb') as compressed:
        compressed.write(payload)

      with TestPipeline() as pipeline:
        lines_read = pipeline | 'Read' >> ReadFromText(gz_path)
        assert_that(lines_read, equal_to(lines))

  def test_read_bzip2(self):
    """A bzip2 file without suffix is read when BZIP2 is named explicitly."""
    _, lines = write_data(15)
    payload = '\n'.join(lines).encode('utf-8')
    with TempDir() as tempdir:
      bz2_path = tempdir.create_temp_file()
      with bz2.BZ2File(bz2_path, 'wb') as compressed:
        compressed.write(payload)

      with TestPipeline() as pipeline:
        read_bzip2 = ReadFromText(
            bz2_path, compression_type=CompressionTypes.BZIP2)
        lines_read = pipeline | 'Read' >> read_bzip2
        assert_that(lines_read, equal_to(lines))

  def test_read_corrupted_bzip2_fails(self):
    """Reading a file overwritten with non-bzip2 bytes as BZIP2 raises."""
    _, lines = write_data(15)
    payload = '\n'.join(lines).encode('utf-8')
    with TempDir() as tempdir:
      bz2_path = tempdir.create_temp_file()
      with bz2.BZ2File(bz2_path, 'wb') as compressed:
        compressed.write(payload)

      # Replace the valid archive with bytes that are not bzip2 data.
      with open(bz2_path, 'wb') as clobbered:
        clobbered.write(b'corrupt')

      with self.assertRaises(Exception), TestPipeline() as pipeline:
        read_bzip2 = ReadFromText(
            bz2_path, compression_type=CompressionTypes.BZIP2)
        lines_read = pipeline | 'Read' >> read_bzip2
        assert_that(lines_read, equal_to(lines))

  def test_read_bzip2_concat(self):
    """Three bzip2 streams concatenated in one file are all read.

    Each group of lines is compressed into its own file; the three files are
    then appended byte-for-byte into a fourth file, which is read as BZIP2.
    """
    line_groups = (['a', 'b', 'c'], ['p', 'q', 'r'], ['x', 'y', 'z'])
    with TempDir() as tempdir:
      part_paths = []
      for group in line_groups:
        part_path = tempdir.create_temp_file()
        text = '\n'.join(group) + '\n'
        with bz2.BZ2File(part_path, 'wb') as dst:
          dst.write(text.encode('utf-8'))
        part_paths.append(part_path)

      final_bzip2_file = tempdir.create_temp_file()
      # The first part truncates the target file, the others are appended.
      for open_mode, part_path in zip(('wb', 'ab', 'ab'), part_paths):
        with open(part_path, 'rb') as src, \
            open(final_bzip2_file, open_mode) as dst:
          dst.write(src.read())

      with TestPipeline() as pipeline:
        lines_read = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
            final_bzip2_file,
            compression_type=beam.io.filesystem.CompressionTypes.BZIP2)

        expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
        assert_that(lines_read, equal_to(expected))

  def test_read_deflate(self):
    """A zlib-compressed file is read when DEFLATE is named explicitly."""
    _, lines = write_data(15)
    payload = zlib.compress('\n'.join(lines).encode('utf-8'))
    with TempDir() as tempdir:
      deflate_path = tempdir.create_temp_file()
      with open(deflate_path, 'wb') as compressed:
        compressed.write(payload)

      with TestPipeline() as pipeline:
        read_deflate = ReadFromText(
            deflate_path,
            0,
            CompressionTypes.DEFLATE,
            True,
            coders.StrUtf8Coder())
        lines_read = pipeline | 'Read' >> read_deflate
        assert_that(lines_read, equal_to(lines))

  def test_read_corrupted_deflate_fails(self):
    """Reading a file overwritten with non-zlib bytes as DEFLATE raises."""
    _, lines = write_data(15)
    payload = zlib.compress('\n'.join(lines).encode('utf-8'))
    with TempDir() as tempdir:
      deflate_path = tempdir.create_temp_file()
      with open(deflate_path, 'wb') as compressed:
        compressed.write(payload)

      # Replace the valid stream with bytes that are not zlib data.
      with open(deflate_path, 'wb') as clobbered:
        clobbered.write(b'corrupt')

      with self.assertRaises(Exception), TestPipeline() as pipeline:
        read_deflate = ReadFromText(
            deflate_path,
            0,
            CompressionTypes.DEFLATE,
            True,
            coders.StrUtf8Coder())
        lines_read = pipeline | 'Read' >> read_deflate
        assert_that(lines_read, equal_to(lines))

  def test_read_deflate_concat(self):
    """Three zlib streams concatenated in one file are all read.

    Each group of lines is compressed into its own file; the three files are
    then appended byte-for-byte into a fourth file, which is read as DEFLATE.
    """
    line_groups = (['a', 'b', 'c'], ['p', 'q', 'r'], ['x', 'y', 'z'])
    with TempDir() as tempdir:
      part_paths = []
      for group in line_groups:
        part_path = tempdir.create_temp_file()
        text = '\n'.join(group) + '\n'
        with open(part_path, 'wb') as dst:
          dst.write(zlib.compress(text.encode('utf-8')))
        part_paths.append(part_path)

      final_deflate_file = tempdir.create_temp_file()
      # The first part truncates the target file, the others are appended.
      for open_mode, part_path in zip(('wb', 'ab', 'ab'), part_paths):
        with open(part_path, 'rb') as src, \
            open(final_deflate_file, open_mode) as dst:
          dst.write(src.read())

      with TestPipeline() as pipeline:
        lines_read = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
            final_deflate_file,
            compression_type=beam.io.filesystem.CompressionTypes.DEFLATE)

        expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
        assert_that(lines_read, equal_to(expected))

  def test_read_gzip(self):
    """Reads a single gzip-compressed file through ReadFromText."""
    _, expected_lines = write_data(15)
    raw_bytes = '\n'.join(expected_lines).encode('utf-8')
    with TempDir() as tempdir:
      gz_path = tempdir.create_temp_file()
      with gzip.GzipFile(gz_path, 'wb') as gz:
        gz.write(raw_bytes)

      with TestPipeline() as pipeline:
        read_transform = ReadFromText(
            file_pattern=gz_path,
            min_bundle_size=0,
            compression_type=CompressionTypes.GZIP,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder())
        pcoll = pipeline | 'Read' >> read_transform
        assert_that(pcoll, equal_to(expected_lines))

  def test_read_corrupted_gzip_fails(self):
    """Expects a failure when a GZIP read hits arbitrary bytes."""
    _, expected_lines = write_data(15)
    with TempDir() as tempdir:
      target_path = tempdir.create_temp_file()
      raw_bytes = '\n'.join(expected_lines).encode('utf-8')
      with gzip.GzipFile(target_path, 'wb') as gz:
        gz.write(raw_bytes)

      # Re-opening with 'wb' truncates the file, so only these bytes remain.
      with open(target_path, 'wb') as out:
        out.write(b'corrupt')

      with self.assertRaises(Exception):
        with TestPipeline() as pipeline:
          read_transform = ReadFromText(
              file_pattern=target_path,
              min_bundle_size=0,
              compression_type=CompressionTypes.GZIP,
              strip_trailing_newlines=True,
              coder=coders.StrUtf8Coder())
          pcoll = pipeline | 'Read' >> read_transform
          assert_that(pcoll, equal_to(expected_lines))

  def test_read_gzip_concat(self):
    """Reads a file built by concatenating three gzip-compressed parts."""
    groups = (['a', 'b', 'c'], ['p', 'q', 'r'], ['x', 'y', 'z'])
    with TempDir() as tempdir:
      # Write each group of lines to its own gzip temp file.
      part_paths = []
      for group in groups:
        part_path = tempdir.create_temp_file()
        text = '\n'.join(group) + '\n'
        with gzip.open(part_path, 'wb') as dst:
          dst.write(text.encode('utf-8'))
        part_paths.append(part_path)

      merged_path = tempdir.create_temp_file()
      # The first part truncates the target file; later parts are appended.
      for index, part_path in enumerate(part_paths):
        mode = 'wb' if index == 0 else 'ab'
        with open(part_path, 'rb') as src, open(merged_path, mode) as dst:
          dst.writelines(src.readlines())

      with TestPipeline() as pipeline:
        read_transform = beam.io.ReadFromText(
            merged_path,
            compression_type=beam.io.filesystem.CompressionTypes.GZIP)
        pcoll = pipeline | 'ReadFromText' >> read_transform

        expected = [line for group in groups for line in group]
        assert_that(pcoll, equal_to(expected))

  def test_read_all_gzip(self):
    """Reads one gzip file whose path is supplied through a PCollection."""
    _, expected_lines = write_data(100)
    with TempDir() as tempdir:
      gz_path = tempdir.create_temp_file()
      with gzip.GzipFile(gz_path, 'wb') as gz:
        gz.write('\n'.join(expected_lines).encode('utf-8'))
      with TestPipeline() as pipeline:
        file_names = pipeline | Create([gz_path])
        read_all = ReadAllFromText(compression_type=CompressionTypes.GZIP)
        pcoll = file_names | 'ReadAll' >> read_all
        assert_that(pcoll, equal_to(expected_lines))

  def test_read_gzip_large(self):
    """Reads a gzip file holding 10000 generated lines."""
    _, expected_lines = write_data(10000)
    raw_bytes = '\n'.join(expected_lines).encode('utf-8')
    with TempDir() as tempdir:
      gz_path = tempdir.create_temp_file()

      with gzip.GzipFile(gz_path, 'wb') as gz:
        gz.write(raw_bytes)

      with TestPipeline() as pipeline:
        read_transform = ReadFromText(
            file_pattern=gz_path,
            min_bundle_size=0,
            compression_type=CompressionTypes.GZIP,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder())
        pcoll = pipeline | 'Read' >> read_transform
        assert_that(pcoll, equal_to(expected_lines))

  def test_read_gzip_large_after_splitting(self):
    """Checks that splitting a gzip source yields at most one split.

    The splits are also compared against the unsplit source to confirm they
    read the same records.
    """
    _, lines = write_data(10000)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      source = TextSource(
          file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
      splits = list(source.split(desired_bundle_size=1000))

      # Use a unittest assertion so a violation is reported as a test
      # failure (with the actual count) instead of an unexpected error.
      self.assertLessEqual(
          len(splits),
          1,
          'FileBasedSource generated more than one initial '
          'split for a compressed file.')

      reference_source_info = (source, None, None)
      sources_info = ([
          (split.source, split.start_position, split.stop_position)
          for split in splits
      ])
      source_test_utils.assert_sources_equal_reference_source(
          reference_source_info, sources_info)

  def test_read_gzip_empty_file(self):
    """Expects no records from a temp file that nothing was written to."""
    with TempDir() as tempdir:
      untouched_path = tempdir.create_temp_file()
      with TestPipeline() as pipeline:
        read_transform = ReadFromText(
            file_pattern=untouched_path,
            min_bundle_size=0,
            compression_type=CompressionTypes.GZIP,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder())
        pcoll = pipeline | 'Read' >> read_transform
        assert_that(pcoll, equal_to([]))

  def _remove_lines(self, lines, sublist_lengths, num_to_remove):
    """Utility function to remove num_to_remove lines from each sublist.

    Args:
      lines: list of items.
      sublist_lengths: list of integers representing length of sublist
        corresponding to each source file.
      num_to_remove: number of lines to remove from each sublist.
    Returns:
      remaining lines.
    """
    curr = 0
    result = []
    for offset in sublist_lengths:
      end = curr + offset
      start = min(curr + num_to_remove, end)
      result += lines[start:end]
      curr += offset
    return result

  def _read_skip_header_lines(self, file_or_pattern, skip_header_lines):
    """Reads all records of an uncompressed TextSource into a list.

    Args:
      file_or_pattern: file name or pattern handed to TextSource.
      skip_header_lines: forwarded to TextSource as ``skip_header_lines``.
    Returns:
      list of the records produced by reading the unbounded range.
    """
    text_source = TextSource(
        file_pattern=file_or_pattern,
        min_bundle_size=0,
        compression_type=CompressionTypes.UNCOMPRESSED,
        strip_trailing_newlines=True,
        coder=coders.StrUtf8Coder(),
        skip_header_lines=skip_header_lines)

    full_range = text_source.get_range_tracker(None, None)
    records = list(text_source.read(full_range))
    return records

  def test_read_skip_header_single(self):
    """Skips one header line when reading a single file."""
    num_records = TextSourceTest.DEFAULT_NUM_RECORDS
    file_name, all_lines = write_data(num_records)
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(all_lines), num_records)
    skip_header_lines = 1
    expected_data = self._remove_lines(
        all_lines, [num_records], skip_header_lines)
    read_data = self._read_skip_header_lines(file_name, skip_header_lines)
    self.assertEqual(len(expected_data), len(read_data))
    self.assertCountEqual(expected_data, read_data)

  def test_read_skip_header_pattern(self):
    """Skips two header lines in each of several files matched by a pattern."""
    multipliers = (5, 3, 12, 8, 8, 4)
    line_counts = [TextSourceTest.DEFAULT_NUM_RECORDS * m for m in multipliers]
    skip_header_lines = 2
    pattern, all_lines = write_pattern(line_counts)

    expected_data = self._remove_lines(
        all_lines, line_counts, skip_header_lines)
    actual_data = self._read_skip_header_lines(pattern, skip_header_lines)
    self.assertEqual(len(expected_data), len(actual_data))
    self.assertCountEqual(expected_data, actual_data)

  def test_read_skip_header_pattern_insufficient_lines(self):
    """Skips four header lines per file when one file has only three lines."""
    # The second file (3 lines) is shorter than the number of lines to skip.
    line_counts = [5, 3, 12, 8, 8, 4]
    skip_header_lines = 4
    pattern, all_lines = write_pattern(line_counts)

    expected_data = self._remove_lines(
        all_lines, line_counts, skip_header_lines)
    actual_data = self._read_skip_header_lines(pattern, skip_header_lines)
    self.assertEqual(len(expected_data), len(actual_data))
    self.assertCountEqual(expected_data, actual_data)

  def test_read_gzip_with_skip_lines(self):
    """Reads a gzip file while skipping its first two lines."""
    _, all_lines = write_data(15)
    raw_bytes = '\n'.join(all_lines).encode('utf-8')
    with TempDir() as tempdir:
      gz_path = tempdir.create_temp_file()
      with gzip.GzipFile(gz_path, 'wb') as gz:
        gz.write(raw_bytes)

      with TestPipeline() as pipeline:
        read_transform = ReadFromText(
            file_pattern=gz_path,
            min_bundle_size=0,
            compression_type=CompressionTypes.GZIP,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder(),
            skip_header_lines=2)
        pcoll = pipeline | 'Read' >> read_transform
        assert_that(pcoll, equal_to(all_lines[2:]))

  def test_read_after_splitting_skip_header(self):
    """Checks header skipping is consistent between split and unsplit reads.

    The source is split into several bundles; the records read from all
    splits together must equal those read from the unsplit source, which in
    turn must equal the written data minus the two skipped lines.
    """
    file_name, expected_data = write_data(100)
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 100)
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        skip_header_lines=2)
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in splits]
    # The comparison is only meaningful if the source really was split.
    self.assertGreater(len(sources_info), 1)
    reference_lines = source_test_utils.read_from_source(*reference_source_info)
    split_lines = []
    for source_info in sources_info:
      split_lines.extend(source_test_utils.read_from_source(*source_info))

    self.assertEqual(expected_data[2:], reference_lines)
    self.assertEqual(reference_lines, split_lines)

  def test_custom_delimiter_read_from_text(self):
    """Reads data written with the delimiter ``@#`` through ReadFromText."""
    file_name, expected_data = write_data(
        5, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=b'@#')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 5)
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> ReadFromText(file_name, delimiter=b'@#')
      assert_that(pcoll, equal_to(expected_data))

  def test_custom_delimiter_read_all_single_file(self):
    """Reads data written with the delimiter ``@#`` through ReadAllFromText."""
    file_name, expected_data = write_data(
        5, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=b'@#')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 5)
    with TestPipeline() as pipeline:
      pcoll = (
          pipeline
          | 'Create' >> Create([file_name])
          | 'ReadAll' >> ReadAllFromText(delimiter=b'@#'))
      assert_that(pcoll, equal_to(expected_data))

  def test_invalid_delimiters_are_rejected(self):
    """Checks that TextSource raises ValueError for unusable delimiters.

    The candidates are an empty bytes object and several non-bytes values.
    """
    file_name, _ = write_data(1)
    for delimiter in (b'', '', '\r\n', 'a', 1):
      # subTest names the offending delimiter in the report and lets the
      # remaining candidates run even if one of them fails.
      with self.subTest(delimiter=delimiter):
        # ``msg`` is only the text shown if no ValueError is raised; it is
        # not matched against the exception message.
        with self.assertRaises(
            ValueError, msg='Delimiter must be a non-empty bytes sequence.'):
          _ = TextSource(
              file_pattern=file_name,
              min_bundle_size=0,
              buffer_size=6,
              compression_type=CompressionTypes.UNCOMPRESSED,
              strip_trailing_newlines=True,
              coder=coders.StrUtf8Coder(),
              delimiter=delimiter,
          )

  def test_non_self_overlapping_delimiter_is_accepted(self):
    """Checks that TextSource can be constructed with each listed delimiter.

    The test passes as long as construction does not raise.
    """
    file_name, _ = write_data(1)
    for delimiter in (b'\n', b'\r\n', b'*', b'abc', b'cabdab', b'abcabd'):
      # subTest names the offending delimiter in the report and lets the
      # remaining candidates run even if one of them fails.
      with self.subTest(delimiter=delimiter):
        _ = TextSource(
            file_pattern=file_name,
            min_bundle_size=0,
            buffer_size=6,
            compression_type=CompressionTypes.UNCOMPRESSED,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder(),
            delimiter=delimiter,
        )

  def test_self_overlapping_delimiter_is_rejected(self):
    """Checks that TextSource raises ValueError for each listed delimiter.

    Every candidate has a proper prefix that is also a suffix of it.
    """
    file_name, _ = write_data(1)
    for delimiter in (b'||', b'***', b'aba', b'abcab'):
      # subTest names the offending delimiter in the report and lets the
      # remaining candidates run even if one of them fails.
      with self.subTest(delimiter=delimiter):
        # ``msg`` is only the text shown if no ValueError is raised; it is
        # not matched against the exception message.
        with self.assertRaises(ValueError,
                               msg='Delimiter must not self-overlap.'):
          _ = TextSource(
              file_pattern=file_name,
              min_bundle_size=0,
              buffer_size=6,
              compression_type=CompressionTypes.UNCOMPRESSED,
              strip_trailing_newlines=True,
              coder=coders.StrUtf8Coder(),
              delimiter=delimiter,
          )

  def test_read_with_customer_delimiter(self):
    """Reads data back with the same delimiter it was written with.

    Runs once per delimiter in the list below.
    """
    delimiters = [
        b'\n',
        b'\r\n',
        b'*|',
        b'*',
        b'*=-',
    ]

    for delimiter in delimiters:
      # subTest names the failing delimiter in the report and lets the
      # remaining delimiters run even if one of them fails.
      with self.subTest(delimiter=delimiter):
        file_name, expected_data = write_data(
            10, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=delimiter)

        # Fixture sanity check via unittest so it is not stripped under
        # ``python -O`` and reports both values on failure.
        self.assertEqual(len(expected_data), 10)
        source = TextSource(
            file_pattern=file_name,
            min_bundle_size=0,
            compression_type=CompressionTypes.UNCOMPRESSED,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder(),
            delimiter=delimiter)
        range_tracker = source.get_range_tracker(None, None)
        read_data = list(source.read(range_tracker))

        self.assertEqual(read_data, expected_data)

  def test_read_with_custom_delimiter_around_split_point(self):
    """Compares split and unsplit reads for several small bundle sizes.

    For every delimiter the source is split with desired bundle sizes 4 to 7
    and the splits are checked against the unsplit source.
    """
    for delimiter in (b'\n', b'\r\n', b'@#', b'abc'):
      file_name, expected_data = write_data(
          20, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=delimiter)
      # Fixture sanity check via unittest so it is not stripped under
      # ``python -O`` and reports both values on failure.
      self.assertEqual(len(expected_data), 20)
      for desired_bundle_size in (4, 5, 6, 7):
        # subTest identifies the failing combination and lets the remaining
        # combinations run even if one of them fails.
        with self.subTest(delimiter=delimiter,
                          desired_bundle_size=desired_bundle_size):
          source = TextSource(
              file_name,
              0,
              CompressionTypes.UNCOMPRESSED,
              True,
              coders.StrUtf8Coder(),
              delimiter=delimiter)
          splits = list(source.split(desired_bundle_size=desired_bundle_size))

          reference_source_info = (source, None, None)
          sources_info = ([
              (split.source, split.start_position, split.stop_position)
              for split in splits
          ])
          source_test_utils.assert_sources_equal_reference_source(
              reference_source_info, sources_info)

  def test_read_with_customer_delimiter_truncated(self):
    """
    Corner case: delimiter truncated at the end of the file
    Use delimiter with length = 3, buffer_size = 6
    and line_value with length = 4
    to split the delimiter
    """
    delimiter = b'@$*'

    file_name, expected_data = write_data(
        10,
        eol=EOL.CUSTOM_DELIMITER,
        line_value=b'a' * 4,
        custom_delimiter=delimiter)

    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 10)
    source = TextSource(
        file_pattern=file_name,
        min_bundle_size=0,
        buffer_size=6,
        compression_type=CompressionTypes.UNCOMPRESSED,
        strip_trailing_newlines=True,
        coder=coders.StrUtf8Coder(),
        delimiter=delimiter,
    )
    range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(range_tracker))

    self.assertEqual(read_data, expected_data)

  def test_read_with_customer_delimiter_over_buffer_size(self):
    """
    Corner case: delimiter is on border of size of buffer
    """
    file_name, expected_data = write_data(3, eol=EOL.CRLF, line_value=b'\rline')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    self._run_read_test(
        file_name, expected_data, buffer_size=7, delimiter=b'\r\n')

  def test_read_with_customer_delimiter_truncated_and_not_equal(self):
    """
    Corner case: the data only ever contains a prefix of the read delimiter.

    Records are written with the two-byte delimiter ``@$`` but read with the
    three-byte delimiter ``@$*``, using buffer_size = 6 and line values of
    length 4. The expected result is a single record made of all written
    lines joined by, and ending with, the write delimiter.
    """
    written_delimiter = b'@$'
    requested_delimiter = b'@$*'

    file_name, written_lines = write_data(
        10,
        eol=EOL.CUSTOM_DELIMITER,
        line_value=b'a' * 4,
        custom_delimiter=written_delimiter)

    # The read delimiter never matches, so the lines must not be split.
    delimiter_text = written_delimiter.decode('utf-8')
    single_record = delimiter_text.join(written_lines) + delimiter_text
    expected_records = [single_record]

    text_source = TextSource(
        file_pattern=file_name,
        min_bundle_size=0,
        buffer_size=6,
        compression_type=CompressionTypes.UNCOMPRESSED,
        strip_trailing_newlines=True,
        coder=coders.StrUtf8Coder(),
        delimiter=requested_delimiter,
    )
    full_range = text_source.get_range_tracker(None, None)
    actual_records = list(text_source.read(full_range))

    self.assertEqual(actual_records, expected_records)

  def test_read_crlf_split_by_buffer(self):
    """Reads three CRLF-terminated records using a read buffer of size 6."""
    file_name, expected_data = write_data(3, eol=EOL.CRLF)
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    self._run_read_test(file_name, expected_data, buffer_size=6)

  def test_read_escaped_lf(self):
    """Reads records containing a backslash-escaped LF with an escapechar."""
    file_name, expected_data = write_data(
        self.DEFAULT_NUM_RECORDS, eol=EOL.LF, line_value=b'li\\\nne')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), self.DEFAULT_NUM_RECORDS)
    self._run_read_test(file_name, expected_data, escapechar=b'\\')

  def test_read_escaped_crlf(self):
    """Reads records containing an escaped CR and an escaped LF.

    The number of records written equals TextSource.DEFAULT_READ_BUFFER_SIZE.
    """
    num_records = TextSource.DEFAULT_READ_BUFFER_SIZE
    file_name, expected_data = write_data(
        num_records, eol=EOL.CRLF, line_value=b'li\\\r\\\nne')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), num_records)
    self._run_read_test(file_name, expected_data, escapechar=b'\\')

  def test_read_escaped_cr_before_not_escaped_lf(self):
    """Reads values holding an escaped CR followed by an unescaped LF.

    Each written value is expected to come back as two records, split at the
    unescaped LF.
    """
    file_name, written_values = write_data(
        self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
    # Every written value contributes the pieces on either side of its LF.
    expected_data = [
        piece for value in written_values for piece in value.split("\n")
    ]
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), self.DEFAULT_NUM_RECORDS * 2)
    self._run_read_test(file_name, expected_data, escapechar=b'\\')

  def test_read_escaped_custom_delimiter_crlf(self):
    """Reads values holding an escaped CRLF with CRLF as explicit delimiter."""
    file_name, expected_data = write_data(
        self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), self.DEFAULT_NUM_RECORDS)
    self._run_read_test(
        file_name, expected_data, delimiter=b'\r\n', escapechar=b'\\')

  def test_read_escaped_custom_delimiter(self):
    """Reads values holding an escaped occurrence of the delimiter ``*|``.

    The number of records written equals TextSource.DEFAULT_READ_BUFFER_SIZE.
    """
    num_records = TextSource.DEFAULT_READ_BUFFER_SIZE
    file_name, expected_data = write_data(
        num_records,
        eol=EOL.CUSTOM_DELIMITER,
        custom_delimiter=b'*|',
        line_value=b'li\\*|ne')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), num_records)
    self._run_read_test(
        file_name, expected_data, delimiter=b'*|', escapechar=b'\\')

  def test_read_escaped_lf_at_buffer_edge(self):
    """Reads values ending in an escaped LF using a read buffer of size 5."""
    file_name, expected_data = write_data(3, eol=EOL.LF, line_value=b'line\\\n')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    self._run_read_test(
        file_name, expected_data, buffer_size=5, escapechar=b'\\')

  def test_read_escaped_crlf_split_by_buffer(self):
    """Reads values ending in an escaped CRLF using a read buffer of size 6."""
    file_name, expected_data = write_data(
        3, eol=EOL.CRLF, line_value=b'line\\\r\n')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    self._run_read_test(
        file_name,
        expected_data,
        buffer_size=6,
        delimiter=b'\r\n',
        escapechar=b'\\')

  def test_read_escaped_lf_after_splitting(self):
    """Compares split and unsplit reads of values ending in an escaped LF."""
    file_name, expected_data = write_data(3, line_value=b'line\\\n')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=6))

    reference_source_info = (source, None, None)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in splits]
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)

  def test_read_escaped_lf_after_splitting_many(self):
    """Compares split and unsplit reads of values made of 5 escapes and LF."""
    file_name, expected_data = write_data(
        3, line_value=b'\\\\\\\\\\\n')  # 5 escapes
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=6))

    reference_source_info = (source, None, None)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in splits]
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)

  def test_read_escaped_escapechar_after_splitting(self):
    """Compares split and unsplit reads with escape chars before ``*|``.

    Each value ends in two escape characters followed by the ``*|`` sequence
    that is also used as the read delimiter.
    """
    file_name, expected_data = write_data(3, line_value=b'line\\\\*|')
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        delimiter=b'*|',
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=8))

    reference_source_info = (source, None, None)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in splits]
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)

  def test_read_escaped_escapechar_after_splitting_many(self):
    """Compares split and unsplit reads with 6 escape chars before ``*|``.

    Each value is six escape characters followed by the ``*|`` sequence that
    is also used as the read delimiter.
    """
    file_name, expected_data = write_data(
        3, line_value=b'\\\\\\\\\\\\*|')  # 6 escapes
    # Fixture sanity check via unittest so it is not stripped under
    # ``python -O`` and reports both values on failure.
    self.assertEqual(len(expected_data), 3)
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        delimiter=b'*|',
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=8))

    reference_source_info = (source, None, None)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in splits]
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)


class TextSinkTest(unittest.TestCase):
  """Tests for the text sink and the WriteToText transform.

  Every test writes under a fresh temporary directory that is removed again
  in tearDown.
  """
  def setUp(self):
    super().setUp()
    self.lines = [b'Line %d' % d for d in range(100)]
    self.tempdir = tempfile.mkdtemp()
    self.path = self._create_temp_file()

  def tearDown(self):
    if os.path.exists(self.tempdir):
      shutil.rmtree(self.tempdir)
    super().tearDown()

  def _create_temp_file(self, name='', suffix=''):
    """Returns a unique path inside self.tempdir; the file itself is removed.

    Args:
      name: prefix for the file name; defaults to ``tempfile.template``.
      suffix: suffix for the file name (for example ``'.gz'``).
    Returns:
      path of a temporary file that has already been deleted again.
    """
    if not name:
      name = tempfile.template
    # Close (and thereby delete) the file deterministically instead of
    # relying on garbage collection of the NamedTemporaryFile object.
    with tempfile.NamedTemporaryFile(delete=True,
                                     prefix=name,
                                     dir=self.tempdir,
                                     suffix=suffix) as temp_file:
      return temp_file.name

  def _write_lines(self, sink, lines):
    """Writes every item of ``lines`` to ``self.path`` through ``sink``."""
    f = sink.open(self.path)
    try:
      for line in lines:
        sink.write_record(f, line)
    finally:
      # Release the handle even if writing a record fails.
      sink.close(f)

  def _read_output_lines(self, opener=open):
    """Returns the lines of all files whose path starts with self.path.

    Args:
      opener: callable taking ``(file_name, 'rb')`` and returning a context
        manager with a ``read()`` method, e.g. ``open`` or ``gzip.GzipFile``.
    Returns:
      list of bytes lines, concatenated over the matched files in glob order.
    """
    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with opener(file_name, 'rb') as f:
        read_result.extend(f.read().splitlines())
    return read_result

  def test_write_text_file(self):
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_text_file_empty(self):
    sink = TextSink(self.path)
    self._write_lines(sink, [])

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), [])

  def test_write_bzip2_file(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.BZIP2)
    self._write_lines(sink, self.lines)

    with bz2.BZ2File(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_bzip2_file_auto(self):
    self.path = self._create_temp_file(suffix='.bz2')
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with bz2.BZ2File(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_gzip_file(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.GZIP)
    self._write_lines(sink, self.lines)

    with gzip.GzipFile(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_gzip_file_auto(self):
    self.path = self._create_temp_file(suffix='.gz')
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with gzip.GzipFile(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_gzip_file_empty(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.GZIP)
    self._write_lines(sink, [])

    with gzip.GzipFile(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), [])

  def test_write_deflate_file(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.DEFLATE)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(zlib.decompress(f.read()).splitlines(), self.lines)

  def test_write_deflate_file_auto(self):
    self.path = self._create_temp_file(suffix='.deflate')
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(zlib.decompress(f.read()).splitlines(), self.lines)

  def test_write_deflate_file_empty(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.DEFLATE)
    self._write_lines(sink, [])

    with open(self.path, 'rb') as f:
      self.assertEqual(zlib.decompress(f.read()).splitlines(), [])

  def test_write_text_file_with_header(self):
    header = b'header1\nheader2'
    sink = TextSink(self.path, header=header)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), header.splitlines() + self.lines)

  def test_write_text_file_with_footer(self):
    footer = b'footer1\nfooter2'
    sink = TextSink(self.path, footer=footer)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines + footer.splitlines())

  def test_write_text_file_empty_with_header(self):
    header = b'header1\nheader2'
    sink = TextSink(self.path, header=header)
    self._write_lines(sink, [])

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), header.splitlines())

  def test_write_pipeline(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(self.path)  # pylint: disable=expression-not-assigned

    read_result = self._read_output_lines()

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_non_globalwindow_input(self):
    with TestPipeline() as p:
      _ = (
          p
          | beam.core.Create(self.lines)
          | beam.WindowInto(beam.transforms.window.FixedWindows(1))
          | 'Write' >> WriteToText(self.path))

    read_result = self._read_output_lines()

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_auto_compression(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(self.path, file_name_suffix='.gz')  # pylint: disable=expression-not-assigned

    read_result = self._read_output_lines(gzip.GzipFile)

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_auto_compression_unsharded(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
          self.path + '.gz',
          shard_name_template='')

    read_result = self._read_output_lines(gzip.GzipFile)

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_header(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
      header_text = 'foo'
      pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
          self.path + '.gz',
          shard_name_template='',
          header=header_text)

    read_result = self._read_output_lines(gzip.GzipFile)
    # header_text is automatically encoded in WriteToText
    self.assertEqual(read_result[0], header_text.encode('utf-8'))
    self.assertEqual(sorted(read_result[1:]), sorted(self.lines))

  def test_write_pipeline_footer(self):
    with TestPipeline() as pipeline:
      footer_text = 'footer'
      pcoll = pipeline | beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(   # pylint: disable=expression-not-assigned
        self.path,
        footer=footer_text)

    read_result = self._read_output_lines()

    self.assertEqual(sorted(read_result[:-1]), sorted(self.lines))
    self.assertEqual(read_result[-1], footer_text.encode('utf-8'))

  def test_write_empty(self):
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create([]) | WriteToText(self.path)

    outputs = glob.glob(self.path + '*')
    self.assertEqual(len(outputs), 1)
    with open(outputs[0], 'rb') as f:
      self.assertEqual(list(f.read().splitlines()), [])

  def test_write_empty_skipped(self):
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create([]) | WriteToText(self.path, skip_if_empty=True)

    outputs = list(glob.glob(self.path + '*'))
    self.assertEqual(outputs, [])

  def test_write_max_records_per_shard(self):
    records_per_shard = 13
    lines = [str(i).encode('utf-8') for i in range(100)]
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create(lines) | WriteToText(
          self.path, max_records_per_shard=records_per_shard)

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with open(file_name, 'rb') as f:
        shard_lines = list(f.read().splitlines())
        # No single shard may exceed the configured record limit.
        self.assertLessEqual(len(shard_lines), records_per_shard)
        read_result.extend(shard_lines)
    self.assertEqual(sorted(read_result), sorted(lines))

  def test_write_max_bytes_per_shard(self):
    bytes_per_shard = 300
    max_len = 100
    lines = [b'x' * i for i in range(max_len)]
    header = b'a' * 20
    footer = b'b' * 30
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create(lines) | WriteToText(
          self.path,
          header=header,
          footer=footer,
          max_bytes_per_shard=bytes_per_shard)

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with open(file_name, 'rb') as f:
        contents = f.read()
        # Allowed overshoot: one more record of up to max_len bytes, the
        # footer, and two extra bytes.
        self.assertLessEqual(
            len(contents), bytes_per_shard + max_len + len(footer) + 2)
        shard_lines = list(contents.splitlines())
        self.assertEqual(shard_lines[0], header)
        self.assertEqual(shard_lines[-1], footer)
        read_result.extend(shard_lines[1:-1])
    self.assertEqual(sorted(read_result), sorted(lines))


class CsvTest(unittest.TestCase):
  """Round-trip tests for beam.io.ReadFromCsv and beam.io.WriteToCsv."""
  def test_csv_read_write(self):
    # Rows written by WriteToCsv must be read back unchanged by ReadFromCsv.
    rows = [beam.Row(a='str', b=index) for index in range(3)]
    with tempfile.TemporaryDirectory() as tmp_dir:
      with TestPipeline() as write_pipeline:
        _ = (
            write_pipeline
            | beam.Create(rows)
            | beam.io.WriteToCsv(os.path.join(tmp_dir, 'out')))

      with TestPipeline() as read_pipeline:
        round_tripped = (
            read_pipeline
            | beam.io.ReadFromCsv(os.path.join(tmp_dir, 'out*'))
            # Rebuild a beam.Row from the element's fields so it can be
            # compared with the rows that were written.
            | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))
        assert_that(round_tripped, equal_to(rows))

  def test_csv_read_with_filename(self):
    # With filename_column set, every record read must carry an extra
    # 'source_filename' field holding the path of the file that was read.
    rows = [beam.Row(a='str', b=index) for index in range(3)]
    with tempfile.TemporaryDirectory() as tmp_dir:
      csv_prefix = os.path.join(tmp_dir, 'out.csv')
      with TestPipeline() as write_pipeline:
        _ = (
            write_pipeline
            | beam.Create(rows)
            | beam.io.WriteToCsv(csv_prefix))

      with TestPipeline() as read_pipeline:
        with_filename = (
            read_pipeline
            | beam.io.ReadFromCsv(
                csv_prefix + '*', filename_column='source_filename')
            | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))

        # Resolve the sharded file name that was written under the prefix;
        # exactly one file is expected.
        shard_paths = glob.glob(csv_prefix + '*')
        self.assertEqual(len(shard_paths), 1)
        shard_path = shard_paths[0]

        expected_rows = [
            beam.Row(a=row.a, b=row.b, source_filename=shard_path)
            for row in rows
        ]
        assert_that(with_filename, equal_to(expected_rows))

  def test_non_utf8_csv_read_write(self):
    # A CSV whose header bytes are not valid UTF-8 must be readable with
    # encoding="latin1" and must survive a latin1 write/read round trip.
    raw_csv = b"\xe0,\xe1,\xe2\n0,1,2\n1,2,3\n"
    # Both the direct read and the round trip must yield these records.
    expected_records = [{
        "\u00e0": 0, "\u00e1": 1, "\u00e2": 2
    }, {
        "\u00e0": 1, "\u00e1": 2, "\u00e2": 3
    }]

    with tempfile.TemporaryDirectory() as tmp_dir:
      source_path = os.path.join(tmp_dir, 'input.csv')
      with open(source_path, 'wb') as source_file:
        source_file.write(raw_csv)

      # 1. Read the hand-written input directly.
      with TestPipeline() as read_pipeline:
        direct = (
            read_pipeline
            | 'Read' >> beam.io.ReadFromCsv(source_path, encoding="latin1")
            | 'ToDict' >> beam.Map(lambda x: x._asdict()))
        assert_that(direct, equal_to(expected_records))

      # 2. Copy the input to a new latin1-encoded CSV.
      with TestPipeline() as copy_pipeline:
        _ = (
            copy_pipeline
            | 'Read' >> beam.io.ReadFromCsv(source_path, encoding="latin1")
            | 'Write' >> beam.io.WriteToCsv(
                os.path.join(tmp_dir, 'out'), encoding="latin1"))

      # 3. Read the copy back and expect the same records.
      with TestPipeline() as reread_pipeline:
        round_tripped = (
            reread_pipeline
            | 'Read' >> beam.io.ReadFromCsv(
                os.path.join(tmp_dir, 'out*'), encoding="latin1")
            | 'ToDict' >> beam.Map(lambda x: x._asdict()))
        assert_that(round_tripped, equal_to(expected_records))


class JsonTest(unittest.TestCase):
  """Round-trip tests for beam.io.ReadFromJson and beam.io.WriteToJson."""
  def test_json_read_write(self):
    # Rows written by WriteToJson must be read back unchanged by ReadFromJson.
    rows = [beam.Row(a='str', b=index) for index in range(3)]
    with tempfile.TemporaryDirectory() as tmp_dir:
      with TestPipeline() as write_pipeline:
        _ = (
            write_pipeline
            | beam.Create(rows)
            | beam.io.WriteToJson(os.path.join(tmp_dir, 'out')))

      with TestPipeline() as read_pipeline:
        round_tripped = (
            read_pipeline
            | beam.io.ReadFromJson(os.path.join(tmp_dir, 'out*'))
            # Rebuild a beam.Row from the element's fields so it can be
            # compared with the rows that were written.
            | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))
        assert_that(round_tripped, equal_to(rows))

  def test_numeric_strings_preserved(self):
    # Strings that look like numbers must come back as strings, and ints and
    # floats must come back with their original types.
    rows = [
        beam.Row(
            as_string=str(index),
            as_float_string=str(float(index)),
            as_int=index,
            as_float=float(index)) for index in range(3)
    ]
    with tempfile.TemporaryDirectory() as tmp_dir:
      with TestPipeline() as write_pipeline:
        _ = (
            write_pipeline
            | beam.Create(rows)
            | beam.io.WriteToJson(os.path.join(tmp_dir, 'out')))

      with TestPipeline() as read_pipeline:
        round_tripped = (
            read_pipeline
            | beam.io.ReadFromJson(os.path.join(tmp_dir, 'out*'))
            | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))

        assert_that(round_tripped, equal_to(rows))

        # This check should be redundant, since Python equality does not
        # equate numeric values with their string representations, but it
        # states explicitly what is being asserted: each field of every
        # element has the same type as the matching field of the first row.
        def check_types(element):
          for actual, reference in zip(element, rows[0]):
            assert type(actual) == type(reference), (
                actual, reference, type(actual), type(reference))

        _ = round_tripped | beam.Map(check_types)


class GenerateEvent(beam.PTransform):
  """Test source that replays a fixed, timestamped TestStream."""
  @staticmethod
  def sample_data():
    """Returns a new GenerateEvent transform."""
    return GenerateEvent()

  def expand(self, input):
    """Applies the scripted TestStream to ``input``.

    The same three records are emitted once per second for event times
    2021-03-01T00:00:01 through 00:00:20 (UTC). The watermark is advanced to
    seconds 5, 10, 15 and 20 immediately before the batch carrying that
    timestamp, then to second 25, and finally to infinity.

    Args:
      input: the value the TestStream is applied to.

    Returns:
      The result of applying the TestStream to ``input``.
    """
    records = [{'age': 10}, {'age': 20}, {'age': 30}]

    def event_time(second):
      # Seconds since the epoch for 2021-03-01T00:00:<second> UTC.
      return datetime(2021, 3, 1, 0, 0, second, 0, tzinfo=pytz.UTC).timestamp()

    stream = TestStream()
    for second in range(1, 21):
      # Every fifth second the watermark is moved up first, and only then is
      # the batch stamped with that same second added.
      if second % 5 == 0:
        stream = stream.advance_watermark_to(event_time(second))
      stream = stream.add_elements(
          elements=records, event_timestamp=event_time(second))
    stream = stream.advance_watermark_to(event_time(25))
    return input | stream.advance_watermark_to_infinity()


class WriteStreamingTest(unittest.TestCase):
  """Tests for WriteToText applied to the GenerateEvent test stream.

  Every test writes under a fresh temporary directory and then checks the
  names and the number of the files found under the output prefix.
  """

  # File names expected for the default shard name template, e.g.
  # ouput_WriteToText-[1614556800.0, 1614556805.0)-00000-of-00002.txt
  # Captures: window_start, window_end, shard_num, total_shards.
  _DEFAULT_TEMPLATE_PATTERN = re.compile(
      r'.*-\[(?P<window_start>[\d\.]+), '
      r'(?P<window_end>[\d\.]+|Infinity)\)-'
      r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.txt$')

  # File names expected for the '-V-SSSSS-of-NNNNN' shard name template, e.g.
  # ouput_WriteToText-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
  #   00000-of-00002.txt
  # Captures: window_start, window_end, shard_num, total_shards.
  _CUSTOM_TEMPLATE_PATTERN = re.compile(
      r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
      r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
      r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.txt$')

  def setUp(self):
    super().setUp()
    self.tempdir = tempfile.mkdtemp()

  def tearDown(self):
    if os.path.exists(self.tempdir):
      shutil.rmtree(self.tempdir)

  @staticmethod
  def _fixed_windows(size):
    """Returns the labelled WindowInto used by the windowed-input tests.

    Args:
      size: the window size passed to FixedWindows.
    """
    return 'User windowing' >> beam.transforms.core.WindowInto(
        beam.transforms.window.FixedWindows(size),
        trigger=beam.transforms.trigger.AfterWatermark(),
        accumulation_mode=beam.transforms.trigger.AccumulationMode.DISCARDING,
        allowed_lateness=beam.utils.timestamp.Duration(seconds=0))

  def _write_and_log(self, pcoll, **write_kwargs):
    """Writes ``pcoll`` under the temp dir and logs WriteToText's output.

    Args:
      pcoll: the PCollection to write.
      **write_kwargs: extra keyword arguments forwarded to WriteToText, such
        as num_shards, triggering_frequency or shard_name_template.
    """
    written = pcoll | 'TextIO WriteToText' >> beam.io.WriteToText(
        file_path_prefix=self.tempdir + "/ouput_WriteToText",
        file_name_suffix=".txt",
        **write_kwargs)
    _ = written | 'LogElements after WriteToText' >> LogElements(
        prefix='after WriteToText ', with_window=True, level=logging.INFO)

  def _matching_output_files(self, pattern):
    """Returns the files under the output prefix, checking their names.

    Args:
      pattern: compiled regex that every file name must match.

    Returns:
      The list of file names found under the output prefix.

    Raises:
      AssertionError: if a file name does not match ``pattern``.
    """
    file_names = []
    for file_name in glob.glob(self.tempdir + '/ouput_WriteToText*'):
      self.assertIsNotNone(
          pattern.match(file_name),
          f"File name {file_name} did not match expected pattern.")
      file_names.append(file_name)
    print("Found files matching expected pattern:", file_names)
    return file_names

  def _assert_file_count(self, file_names, expected, at_least=False):
    """Asserts on the number of files, reporting the real expected value.

    Args:
      file_names: the file names that were found.
      expected: the required number of files.
      at_least: if True, ``expected`` is a lower bound instead of an exact
        count.
    """
    check = self.assertGreaterEqual if at_least else self.assertEqual
    # The message is built from the value that is actually asserted on, so
    # the two cannot get out of sync.
    check(
        len(file_names),
        expected,
        "expected %d files, but got: %d" % (expected, len(file_names)))

  def test_write_streaming_2_shards_default_shard_name_template(
      self, num_shards=2):
    with TestPipeline() as p:
      self._write_and_log(
          p | GenerateEvent.sample_data(),
          num_shards=num_shards,
          triggering_frequency=60)

    file_names = self._matching_output_files(self._DEFAULT_TEMPLATE_PATTERN)
    self._assert_file_count(file_names, num_shards)

  def test_write_streaming_2_shards_default_shard_name_template_windowed_pcoll(
      self, num_shards=2):
    with TestPipeline() as p:
      windowed = p | GenerateEvent.sample_data() | self._fixed_windows(10)
      self._write_and_log(windowed, num_shards=num_shards)

    file_names = self._matching_output_files(self._DEFAULT_TEMPLATE_PATTERN)
    # 25s of data covered by 3 10s windows
    self._assert_file_count(file_names, num_shards * 3)

  def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll(  # pylint: disable=line-too-long
      self):
    with TestPipeline() as p:
      windowed = p | GenerateEvent.sample_data() | self._fixed_windows(10)
      self._write_and_log(windowed, num_shards=0)

    file_names = self._matching_output_files(self._DEFAULT_TEMPLATE_PATTERN)
    # 25s of data covered by 3 10s windows, at least one file per window
    self._assert_file_count(file_names, 1 * 3, at_least=True)

  def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll_and_trig_freq(  # pylint: disable=line-too-long
      self):
    with TestPipeline() as p:
      windowed = p | GenerateEvent.sample_data() | self._fixed_windows(60)
      self._write_and_log(windowed, num_shards=0, triggering_frequency=10)

    file_names = self._matching_output_files(self._DEFAULT_TEMPLATE_PATTERN)
    # 25s of data covered by 3 10s windows
    # NOTE(review): this presumes the 10s triggering_frequency, not the 60s
    # user windowing, determines the output windows -- confirm.
    self._assert_file_count(file_names, 1 * 3, at_least=True)

  def test_write_streaming_undef_shards_default_shard_name_template_global_window_pcoll(  # pylint: disable=line-too-long
      self):
    with TestPipeline() as p:
      self._write_and_log(
          p | GenerateEvent.sample_data(),
          num_shards=0,  #0 means undef nb of shards, same as omitted/default
          triggering_frequency=60)

    file_names = self._matching_output_files(self._DEFAULT_TEMPLATE_PATTERN)
    # 25s of data covered by 60s windows
    self._assert_file_count(file_names, 1, at_least=True)

  def test_write_streaming_2_shards_custom_shard_name_template(
      self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'):
    with TestPipeline() as p:
      self._write_and_log(
          p | GenerateEvent.sample_data(),
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=60)

    file_names = self._matching_output_files(self._CUSTOM_TEMPLATE_PATTERN)
    self._assert_file_count(file_names, num_shards)

  def test_write_streaming_2_shards_custom_shard_name_template_5s_window(
      self,
      num_shards=2,
      shard_name_template='-V-SSSSS-of-NNNNN',
      triggering_frequency=5):
    with TestPipeline() as p:
      self._write_and_log(
          p | GenerateEvent.sample_data(),
          shard_name_template=shard_name_template,
          num_shards=num_shards,
          triggering_frequency=triggering_frequency)

    file_names = self._matching_output_files(self._CUSTOM_TEMPLATE_PATTERN)
    # for 5s window size, the input should be processed by 5 windows with
    # 2 shards per window
    self._assert_file_count(file_names, 10)


if __name__ == '__main__':
  # Set the root logger to INFO before running the tests as a script.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()
